{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1\"\n",
    "\n",
    "import gc\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "from fusion_bench.method.pruning.wanda_utils.eval import eval_ppl\n",
    "from fusion_bench.models.modeling_losparse_llama import LoSparseLlamaForCausalLM\n",
    "from fusion_bench.utils import print_parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fusion_bench.models.modeling_losparse_llama.modeling_losparse_llama import (\n",
    "    LoSparseLinear,\n",
    "    LoSparseLlamaForCausalLM,\n",
    ")\n",
    "\n",
    "\n",
    "def model_eval_ppl(model_path):\n",
    "    \"\"\"Load a causal-LM checkpoint and report its perplexity via `eval_ppl`.\n",
    "\n",
    "    The model is loaded in fp16, sharded over the visible GPUs with\n",
    "    `device_map=\"auto\"`, and released again before returning (also when\n",
    "    evaluation raises) so consecutive calls do not accumulate GPU memory.\n",
    "\n",
    "    Args:\n",
    "        model_path: Path of the checkpoint directory to evaluate.\n",
    "\n",
    "    Returns:\n",
    "        The value returned by `eval_ppl` (the perplexity that is printed).\n",
    "    \"\"\"\n",
    "    # Free whatever an earlier call may have left behind before loading.\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        model_path,\n",
    "        torch_dtype=torch.float16,\n",
    "        low_cpu_mem_usage=True,\n",
    "        device_map=\"auto\",\n",
    "    )\n",
    "    try:\n",
    "        print_parameters(model)\n",
    "        # Presumably read by `eval_ppl` as the evaluation window length (TODO confirm).\n",
    "        model.seqlen = model.config.max_position_embeddings\n",
    "        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            result = eval_ppl(model, tokenizer)\n",
    "    finally:\n",
    "        # Drop the multi-GB model eagerly. If evaluation fails, IPython keeps the\n",
    "        # traceback (and therefore this frame) alive, which would otherwise pin\n",
    "        # the weights in GPU memory and make the next call run out of memory.\n",
    "        del model\n",
    "        gc.collect()\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    print(f\"PPL for {model_path}: {result}\")\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dense (unpruned baseline)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5fae0247c9214cb7bd62128973bda66c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aafc1b8d8c204023873acc7cc8c6b81e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/Llama-2-13b-hf: 4.573723793029785\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "4.573723793029785"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/Llama-2-13b-hf\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Magnitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e2daa7211596467cad3ae84c9e8e1007",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 13.02B || all params: 13.02B || trainable%: 100.0000\n",
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "18782a7b4e5a4836a29908731e58fd48",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/magnitude/unstructured/0.5: 5.9772844314575195\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "5.9772844314575195"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/magnitude/unstructured/0.5\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2609b57028d6487581df67993d2d61b3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "966e185aa91a4b4aa49c16be237953ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/magnitude/unstructured/0.6: 9.907220840454102\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "9.907220840454102"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/magnitude/unstructured/0.6\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "99871d998e604df7bd7076686c8bb1dd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "de801df3e0954f47821023b7d70a74e8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/magnitude/unstructured/0.7: 408.7518310546875\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "408.7518310546875"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/magnitude/unstructured/0.7\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b41e211d8e184d649b6a33d3fe0f4b59",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d81f5b56465644b5b8c0393256e766ee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/magnitude/semistructured/2_4: 8.319043159484863\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "8.319043159484863"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/magnitude/semistructured/2_4\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8734366121164e4993e36c761a999daa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "16e4308a708f4c07a797ff723f602670",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/magnitude/semistructured/4_8: 6.7590107917785645\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "6.7590107917785645"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/magnitude/semistructured/4_8\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sparselo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Magnitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "019ad56d1b98487daef5f2f54753dc19",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "45e41da06ddc4dc5b359879b1ddc1aba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/sparselo/magnitude/unstructured/0.5: 5.726884841918945\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "5.726884841918945"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/sparselo/magnitude/unstructured/0.5\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7cafc189992945aa98ac789e7a01018a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a630777ee0404f508936f5d54b117ad0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/sparselo/magnitude/unstructured/0.6: 8.883753776550293\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "8.883753776550293"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/sparselo/magnitude/unstructured/0.6\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "22cce5ca10a84ddb83a98b30e9e494bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "152669ef46584258941f93aee8783e23",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/sparselo/magnitude/unstructured/0.7: 163.96058654785156\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "163.96058654785156"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/sparselo/magnitude/unstructured/0.7\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f97664a11824c9fa9eea5fd60513f1f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "692f0ae75c0f4025add1e6400780b2cb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/sparselo/magnitude/semistructured/2_4: 8.864727973937988\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "8.864727973937988"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/sparselo/magnitude/semistructured/2_4\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e9b0d422f61d464c833b54126222597b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ebd791c7e5b1420ca89b7fadb4676546",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/sparselo/magnitude/semistructured/4_8: 6.580754280090332\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "6.580754280090332"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/sparselo/magnitude/semistructured/4_8\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Iterative SparseLo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Magnitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "41051afb8f554a6194998ec1acbd8ccb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9996cbfc013940caa960d1279a67ec0e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/iterative_sparselo/magnitude/unstructured/0.5: 5.648659706115723\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "5.648659706115723"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/iterative_sparselo/magnitude/unstructured/0.5\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f4275ae332554505951c90a185cb916e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 13.52B || all params: 13.52B || trainable%: 100.0000\n",
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "23d026b91a964335b0159e026fd1ee5d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/iterative_sparselo/magnitude/unstructured/0.6: 8.832003593444824\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "8.832003593444824"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/iterative_sparselo/magnitude/unstructured/0.6\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62baeddafe4f412fab4d33b95d5acf01",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 13.52B || all params: 13.52B || trainable%: 100.0000\n",
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a1effdfa33444993996f773fadd773c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/iterative_sparselo/magnitude/unstructured/0.7: 99.2710189819336\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "99.2710189819336"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/iterative_sparselo/magnitude/unstructured/0.7\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "affc4e3fbfb341e2a456bfc49eb2cb96",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab4b656426ca456890f9a4e1301ce5a5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/iterative_sparselo/magnitude/semistructured/2_4: 7.764371395111084\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "7.764371395111084"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/iterative_sparselo/magnitude/semistructured/2_4\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d53b1eac4ccd475b997e1ae6011ef51f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating on wikitext2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n",
      "Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /data0/users/tanganke/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Wed Aug 28 13:14:33 2024).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "182abc827e3347aab19134602d95dd19",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating samples:   0%|          | 0/128 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nsamples 83\n",
      "sample 0\n",
      "sample 50\n",
      "PPL for /data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/iterative_sparselo/magnitude/semistructured/4_8: 6.39985466003418\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "6.39985466003418"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_eval_ppl(\n",
    "    \"/data0/users/tanganke/projects/fusion_bench/outputs/llama-13b/iterative_sparselo/magnitude/semistructured/4_8\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fusionbench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
